Unthunk tangents (if any) before returning gradient#1551
Unthunk tangents (if any) before returning gradient#1551CarloLucibello merged 6 commits intoFluxML:masterfrom pxl-th:pxl-th/thunk
Conversation
|
Can you come up with a test? |
|
Hm... While it does fix the issue with using Zygote
struct Dense
w::Matrix{Float32}
end
(d::Dense)(x) = d.w * x
function main()
layers = [Dense(rand(Float32, 3, 3))]
x = ones(Float32, 3)
g = gradient(layers -> sum(layers[1](x)), layers)[1]
@show g
end
main()Changing So maybe we unthunk the before returning? |
|
Can you remind me what if anything we lose by unthunking at the top level (before |
|
I think what's happening is that I agree that making It might be worth having |
|
Hm... it was taught, |
|
I don't think it is imported or used, per the comment in the code diff I linked. |
|
It is imported here. |
|
That's strange, because the code in #966 is definitely defining an |
|
I added tests, but I'm a bit out of ideas, so I made it unthunk before returning gradients. |
|
I've tested with Flux and all tests pass (CPU + AMDGPU). Maybe this is fine for now? |
|
The current approach LGTM, but perhaps it would make sense to have the new overloads for |
|
Agree, moved them: FluxML/ZygoteRules.jl#28 |
Based on FluxML/ZygoteRules.jl#28 (comment) and other comments, maybe it'd be better to move on with this PR and have a separate PR that will resolve this. |
|
It also turns out that defining |
|
ToucheSir
left a comment
There was a problem hiding this comment.
In hindsight, #966 should not have been merged with the unthunk_tangent changes because it was accidentally committing type piracy. But undoing that is going to take some work, and it's not clear if ZygoteRules.jl will still exist in its current form by the time someone gets around to doing said work because of how complex the phasing needs to be.
Project.toml
Outdated
| Statistics = "1" | ||
| Tracker = "0.2" | ||
| ZygoteRules = "0.2.5" | ||
| ZygoteRules = "=0.2.5" |
There was a problem hiding this comment.
This bound needs to be updated after FluxML/ZygoteRules.jl#31.
Fixes: FluxML/Flux.jl#2574 (comment)